package top.shuai7boy.trafficTemp.areaRoadFlow

import java.util

import com.alibaba.fastjson.JSONObject
import org.apache.spark.SparkConf
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.api.java.function.{Function, PairFunction}
import org.apache.spark.sql.{Dataset, Row, RowFactory, SaveMode, SparkSession}
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
import vip.shuai7boy.trafficTemp.areaRoadFlow.{ConcatStringStringUDF, GroupConcatDistinctUDAF, RandomPrefixUDF, RemoveRandomPrefixUDF}
import vip.shuai7boy.trafficTemp.conf.ConfigurationManager
import vip.shuai7boy.trafficTemp.constant.Constants
import vip.shuai7boy.trafficTemp.dao.ITaskDAO
import vip.shuai7boy.trafficTemp.dao.factory.DAOFactory
import vip.shuai7boy.trafficTemp.domain.Task
import vip.shuai7boy.trafficTemp.util.{JSONObjectTemp, ParamUtils}
import vip.spark.spark.test.MockData

/**
 * 计算出一个区域top3道路流量
 * 每一个区域车流量最多的3条道路，每条道路有多个卡扣
 * <p>
 * 这是一个分组topN，利用Spark SQL实现分组topN。
 */
object AreaTop3RoadFlowAnalyze {

  /**
   * Entry point. Builds the Spark runtime (local mocked data, or Hive-backed on a
   * cluster), resolves the task parameters by task id, then runs the 5-step
   * pipeline: detail RDD -> area-info RDD -> base temp table -> per-road flow
   * temp table -> top-3 ranking persisted to Hive.
   *
   * Fix vs. original: the Spark resources were leaked when the task lookup
   * returned null (early return before sc.close()/spark.close()); the pipeline
   * is now wrapped in try/finally so they are always released.
   */
  def main(args: Array[String]): Unit = {
    var sc: JavaSparkContext = null
    var spark: SparkSession = null
    // Decide whether the job runs locally (mocked data) or on the cluster.
    val onLocal: Boolean = ConfigurationManager.getBoolean(Constants.SPARK_LOCAL)
    if (onLocal) {
      // Local mode: build the Spark runtime and register mocked source tables.
      val conf: SparkConf = new SparkConf().setAppName(Constants.SPARK_APP_NAME).setMaster("local")
      sc = new JavaSparkContext(conf)
      spark = SparkSession.builder.getOrCreate
      MockData.mock(sc, spark)
    } else {
      System.out.println("++++++++++++++++++++++++++++++++++++++开启hive的支持")
      // Cluster mode: enable Hive support and raise the auto-broadcast-join
      // threshold to ~1 GB so the (small) dimension side gets broadcast.
      spark = SparkSession.builder
        .appName(Constants.SPARK_APP_NAME)
        .config("spark.sql.autoBroadcastJoinThreshold", "1048576000")
        .enableHiveSupport
        .getOrCreate
      sc = new JavaSparkContext(spark.sparkContext)
      spark.sql("use traffic")
    }
    try {
      // Register the custom (aggregate) functions referenced by the SQL below.
      spark.udf.register("concat_String_string", new ConcatStringStringUDF, DataTypes.StringType)
      spark.udf.register("random_prefix", new RandomPrefixUDF, DataTypes.StringType)
      spark.udf.register("remove_random_prefix", new RemoveRandomPrefixUDF, DataTypes.StringType)
      spark.udf.register("group_concat_distinct", new GroupConcatDistinctUDAF)
      // Resolve the task (and its JSON parameters) from the command-line task id.
      val taskDAO: ITaskDAO = DAOFactory.getTaskDAO
      val taskid: Long = ParamUtils.getTaskIdFromArgs(args, Constants.SPARK_LOCAL_TASKID_TOPN_MONITOR_FLOW)
      val task: Task = taskDAO.findTaskById(taskid)
      if (task != null) {
        val taskParam: JSONObject = JSONObjectTemp.parseObject(task.getTaskParams)
        // Step 1: (areaId, detailRow) for every record in the requested date range.
        val areaId2DetailInfos: JavaPairRDD[String, Row] = getInfosByDateRDD(spark, taskParam)
        // Step 2: (areaId, areaName) dimension data loaded from MySQL.
        val areaId2AreaInfoRDD: JavaPairRDD[String, String] = getAreaId2AreaInfoRDD(spark)
        // Step 3: join detail and area info, register temp view tmp_car_flow_basic.
        generateTempRoadFlowBasicTable(spark, areaId2DetailInfos, areaId2AreaInfoRDD)
        // Step 4: aggregate per (area, road) into tmp_area_road_flow_count.
        generateTempAreaRoadFlowTable(spark)
        // Step 5: rank roads per area and persist the top 3.
        getAreaTop3RoadFolwRDD(spark)
        System.out.println("++++++++++++++++++full complete+++++++++++++")
      }
    } finally {
      // Always release Spark resources, even if the task was not found or the
      // pipeline threw (original code leaked them on the null-task path).
      sc.close()
      spark.close()
    }
  }

  /**
   * Step 5: per area, keep the 3 roads with the highest car_count
   * (row_number() OVER (PARTITION BY area_name ORDER BY car_count DESC)),
   * attach a coarse flow_level bucket (>170 A, 161-170 B, 151-160 C, else D),
   * show the result and overwrite Hive table result.areaTop3Road.
   *
   * NOTE: the method name keeps the historical "Folw" typo so existing
   * callers stay source-compatible.
   *
   * @param spark active session; expects temp view tmp_area_road_flow_count
   */
  def getAreaTop3RoadFolwRDD(spark: SparkSession): Unit = {
    // SQL text is byte-identical to the original; only source formatting changed.
    val sql: String =
      "SELECT " +
        "area_name," +
        "road_id," +
        "car_count," +
        "monitor_infos, " +
        "CASE " +
        "WHEN car_count > 170 THEN 'A LEVEL' " +
        "WHEN car_count > 160 AND car_count <= 170 THEN 'B LEVEL' " +
        "WHEN car_count > 150 AND car_count <= 160 THEN 'C LEVEL' " +
        "ELSE 'D LEVEL' " +
        "END flow_level " +
        "FROM (" +
        "SELECT " +
        "area_name," +
        "road_id," +
        "car_count," +
        "monitor_infos," +
        "row_number() OVER (PARTITION BY area_name ORDER BY car_count DESC) rn " +
        "FROM tmp_area_road_flow_count " +
        ") tmp " +
        "WHERE rn <=3"
    val result: Dataset[Row] = spark.sql(sql)
    System.out.println("--------最终的结果-------")
    result.show()
    // Persist into the "result" Hive database (must already exist).
    spark.sql("use result")
    spark.sql("drop table if exists result.areaTop3Road")
    result.write.mode(SaveMode.Overwrite).saveAsTable("areaTop3Road")
  }

  /**
   * Step 4: aggregate tmp_car_flow_basic per (area_name, road_id) — the record
   * count becomes car_count and the distinct monitor ids are concatenated via
   * the group_concat_distinct UDAF — and expose the result as temp view
   * tmp_area_road_flow_count.
   *
   * @param spark active session; expects temp view tmp_car_flow_basic
   */
  def generateTempAreaRoadFlowTable(spark: SparkSession): Unit = {
    val sql: String = "SELECT " + "area_name," + "road_id," + "count(*) car_count," + "group_concat_distinct(monitor_id) monitor_infos " + "FROM tmp_car_flow_basic " + "GROUP BY area_name,road_id"
    val ds: Dataset[Row] = spark.sql(sql)
    // createOrReplaceTempView replaces the deprecated registerTempTable and
    // matches the style used in generateTempRoadFlowBasicTable.
    ds.createOrReplaceTempView("tmp_area_road_flow_count")
  }

  /**
   * Step 3: join the per-area detail rows with the area-name dimension and
   * register the flattened rows (area_id, area_name, road_id, monitor_id, car)
   * as temp view tmp_car_flow_basic.
   *
   * @param spark              active session used to build the DataFrame
   * @param areaId2DetailInfos (areaId, detail Row) pairs from monitor_flow_action
   * @param areaId2AreaInfoRDD (areaId, areaName) pairs from MySQL
   */
  private def generateTempRoadFlowBasicTable(spark: SparkSession, areaId2DetailInfos: JavaPairRDD[String, Row], areaId2AreaInfoRDD: JavaPairRDD[String, String]): Unit = {
    val tmpRowRDD: JavaRDD[Row] = areaId2DetailInfos.join(areaId2AreaInfoRDD).map(new Function[Tuple2[String, Tuple2[Row, String]], Row]() {
      @throws[Exception]
      override def call(tuple: Tuple2[String, Tuple2[Row, String]]): Row = {
        val areaId: String = tuple._1
        val carFlowDetailRow: Row = tuple._2._1
        val areaName: String = tuple._2._2
        // Pull the detail columns by name and append the joined area name.
        val roadId: String = carFlowDetailRow.getAs("road_id")
        val monitorId: String = carFlowDetailRow.getAs("monitor_id")
        val car: String = carFlowDetailRow.getAs("car")
        RowFactory.create(areaId, areaName, roadId, monitorId, car)
      }
    })
    // Explicit schema: all five columns are nullable strings.
    val structFields: util.List[StructField] = new util.ArrayList[StructField]
    structFields.add(DataTypes.createStructField("area_id", DataTypes.StringType, true))
    structFields.add(DataTypes.createStructField("area_name", DataTypes.StringType, true))
    structFields.add(DataTypes.createStructField("road_id", DataTypes.StringType, true))
    structFields.add(DataTypes.createStructField("monitor_id", DataTypes.StringType, true))
    structFields.add(DataTypes.createStructField("car", DataTypes.StringType, true))
    val schema: StructType = DataTypes.createStructType(structFields)
    val df: Dataset[Row] = spark.createDataFrame(tmpRowRDD, schema)
    df.createOrReplaceTempView("tmp_car_flow_basic")
  }

  /**
   * Step 2: load the area_info dimension table from MySQL over JDBC and return
   * it as (areaId, areaName) pairs. Connection settings come from the local or
   * prod JDBC properties depending on the spark.local flag.
   *
   * @param spark active session used for the JDBC read
   * @return pair RDD keyed by area id
   */
  private def getAreaId2AreaInfoRDD(spark: SparkSession): JavaPairRDD[String, String] = {
    var url: String = null
    var user: String = null
    var password: String = null
    val local: Boolean = ConfigurationManager.getBoolean(Constants.SPARK_LOCAL)
    // Pick local vs. production JDBC credentials.
    if (local) {
      url = ConfigurationManager.getProperty(Constants.JDBC_URL)
      user = ConfigurationManager.getProperty(Constants.JDBC_USER)
      password = ConfigurationManager.getProperty(Constants.JDBC_PASSWORD)
    } else {
      url = ConfigurationManager.getProperty(Constants.JDBC_URL_PROD)
      user = ConfigurationManager.getProperty(Constants.JDBC_USER_PROD)
      password = ConfigurationManager.getProperty(Constants.JDBC_PASSWORD_PROD)
    }
    val options: util.Map[String, String] = new util.HashMap[String, String]
    options.put("url", url)
    options.put("driver", "com.mysql.jdbc.Driver")
    options.put("user", user)
    options.put("password", password)
    options.put("dbtable", "area_info")
    // Read the whole area_info table through the Spark JDBC source.
    val areaInfoDF: Dataset[Row] = spark.read.format("jdbc").options(options).load
    System.out.println("------------Mysql数据库中的表area_info数据为------------")
    areaInfoDF.show()
    // Re-key by area id; columns are read positionally (0 = id, 1 = name).
    val areaInfoRDD: JavaRDD[Row] = areaInfoDF.javaRDD
    val areaid2areaInfoRDD: JavaPairRDD[String, String] = areaInfoRDD.mapToPair(new PairFunction[Row, String, String]() {
      @throws[Exception]
      override def call(row: Row): Tuple2[String, String] = {
        val areaid: String = String.valueOf(row.get(0))
        val areaname: String = String.valueOf(row.get(1))
        new Tuple2[String, String](areaid, areaname)
      }
    })
    areaid2areaInfoRDD
  }

  /**
   * Step 1: select monitor records from traffic.monitor_flow_action whose date
   * lies in [startDate, endDate] (taken from the task parameters) and return
   * them keyed by area_id.
   *
   * Fix vs. original: the generated SQL was missing the space between the
   * closing quote of startDate and the AND keyword ("...'AND date <= ...").
   *
   * @param spark     active session
   * @param taskParam task parameter JSON holding the start/end dates
   * @return (areaId, detail Row) pairs
   */
  private def getInfosByDateRDD(spark: SparkSession, taskParam: JSONObject): JavaPairRDD[String, Row] = {
    val startDate: String = ParamUtils.getParam(taskParam, Constants.PARAM_START_DATE)
    val endDate: String = ParamUtils.getParam(taskParam, Constants.PARAM_END_DATE)
    // NOTE: dates come from trusted task config, not end-user input; otherwise
    // this string-built SQL would need parameterization.
    val sql: String = "SELECT " +
      "monitor_id," +
      "car," +
      "road_id," +
      "area_id " +
      "FROM\ttraffic.monitor_flow_action " +
      s"WHERE date >= '$startDate' " +
      s"AND date <= '$endDate'"
    val df: Dataset[Row] = spark.sql(sql)
    df.javaRDD.mapToPair(new PairFunction[Row, String, Row]() {
      @throws[Exception]
      override def call(row: Row): Tuple2[String, Row] = {
        val areaId: String = row.getAs("area_id")
        new Tuple2[String, Row](areaId, row)
      }
    })
  }
}
